{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "77448887-ed71-48e1-bf0d-2ff499d0c7ca",
   "metadata": {},
   "source": [
    "# Lesson 1 - Semantic Search"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c3ea165",
   "metadata": {},
   "source": [
    "Welcome to Lesson 1. \n",
    "\n",
    "To access the `requirement.txt` file, go to `File` and click on `Open`.\n",
    " \n",
    "I hope you enjoy this course!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "750cd7dd",
   "metadata": {},
   "source": [
    "### Import the Needed Packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "809aa032-d737-450d-aafa-e32bfba9d8f8",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30698fb9-4709-4088-9905-9ccb4efd5e09",
   "metadata": {
    "height": 166
   },
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "from sentence_transformers import SentenceTransformer\n",
    "from pinecone import Pinecone, ServerlessSpec\n",
    "from DLAIUtils import Utils\n",
    "import DLAIUtils\n",
    "\n",
    "import os\n",
    "import time\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ab484bb-3bfb-4c52-a5bd-bcbe4a7a63d2",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba92fc2d",
   "metadata": {},
   "source": [
    "### Load the Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce859e4b-9b50-4f53-b357-28d3e3872c87",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "dataset = load_dataset('quora', split='train[240000:290000]')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "356d4112-fa51-4092-9841-8b266e3b6a2c",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "dataset[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "332d1241-61ae-4d09-bf46-52081c133c0c",
   "metadata": {
    "height": 132
   },
   "outputs": [],
   "source": [
    "questions = []\n",
    "for record in dataset['questions']:\n",
    "    questions.extend(record['text'])\n",
    "question = list(set(questions))\n",
    "print('\\n'.join(questions[:10]))\n",
    "print('-' * 50)\n",
    "print(f'Number of questions: {len(questions)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab3c0402",
   "metadata": {},
   "source": [
    "### Check cuda and Setup the model\n",
    "\n",
    "**Note**: \"Checking cuda\" refers to checking if you have access to GPUs (faster compute). In this course, we are using CPUs. So, you might notice some code cells taking a little longer to run.\n",
    "\n",
    "We are using *all-MiniLM-L6-v2* sentence-transformers model that maps sentences to a 384 dimensional dense vector space."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fb67759-ab38-4472-bfb0-4a56d1c05955",
   "metadata": {
    "height": 81
   },
   "outputs": [],
   "source": [
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "if device != 'cuda':\n",
    "    print('Sorry no cuda.')\n",
    "model = SentenceTransformer('all-MiniLM-L6-v2', device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d90ec5ec-5397-4ed5-8163-7a901b6ecb0c",
   "metadata": {
    "height": 64
   },
   "outputs": [],
   "source": [
    "query = 'which city is the most populated in the world?'\n",
    "xq = model.encode(query)\n",
    "xq.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1780a189",
   "metadata": {},
   "source": [
    "### Setup Pinecone"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3e3a94e-127f-4667-a9ae-7a17d7304ee6",
   "metadata": {
    "height": 43
   },
   "outputs": [],
   "source": [
    "utils = Utils()\n",
    "PINECONE_API_KEY = utils.get_pinecone_api_key()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17a75eca-60f0-478b-bdcf-b68732c1545d",
   "metadata": {
    "height": 200
   },
   "outputs": [],
   "source": [
    "pinecone = Pinecone(api_key=PINECONE_API_KEY)\n",
    "INDEX_NAME = utils.create_dlai_index_name('dl-ai')\n",
    "\n",
    "if INDEX_NAME in [index.name for index in pinecone.list_indexes()]:\n",
    "    pinecone.delete_index(INDEX_NAME)\n",
    "print(INDEX_NAME)\n",
    "pinecone.create_index(name=INDEX_NAME, \n",
    "    dimension=model.get_sentence_embedding_dimension(), \n",
    "    metric='cosine',\n",
    "    spec=ServerlessSpec(cloud='aws', region='us-west-2'))\n",
    "\n",
    "index = pinecone.Index(INDEX_NAME)\n",
    "print(index)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d88d9424",
   "metadata": {},
   "source": [
    "### Create Embeddings and Upsert to Pinecone"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea551303-aa81-41cd-adc5-dc9ea8072397",
   "metadata": {
    "height": 352
   },
   "outputs": [],
   "source": [
    "batch_size=200\n",
    "vector_limit=10000\n",
    "\n",
    "questions = question[:vector_limit]\n",
    "\n",
    "import json\n",
    "\n",
    "for i in tqdm(range(0, len(questions), batch_size)):\n",
    "    # find end of batch\n",
    "    i_end = min(i+batch_size, len(questions))\n",
    "    # create IDs batch\n",
    "    ids = [str(x) for x in range(i, i_end)]\n",
    "    # create metadata batch\n",
    "    metadatas = [{'text': text} for text in questions[i:i_end]]\n",
    "    # create embeddings\n",
    "    xc = model.encode(questions[i:i_end])\n",
    "    # create records list for upsert\n",
    "    records = zip(ids, xc, metadatas)\n",
    "    # upsert to Pinecone\n",
    "    index.upsert(vectors=records)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6153920a-f4c4-420e-9790-262dd2299fc6",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "index.describe_index_stats()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71135418",
   "metadata": {},
   "source": [
    "### Run Your Query"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28b20b81-4782-4ce2-aec3-9576c7779f2e",
   "metadata": {
    "height": 115
   },
   "outputs": [],
   "source": [
    "# small helper function so we can repeat queries later\n",
    "def run_query(query):\n",
    "  embedding = model.encode(query).tolist()\n",
    "  results = index.query(top_k=10, vector=embedding, include_metadata=True, include_values=False)\n",
    "  for result in results['matches']:\n",
    "    print(f\"{round(result['score'], 2)}: {result['metadata']['text']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c4244d6-7be0-4ee1-a36a-0d586f0555f7",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "run_query('which city has the highest population in the world?')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f85b20fe-2328-47eb-84cd-bf3b27c6d4aa",
   "metadata": {
    "height": 47
   },
   "outputs": [],
   "source": [
    "query = 'how do i make chocolate cake?'\n",
    "run_query(query)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
